#############################################################################################
# NKPH moment methods
#############################################################################################


###########################################################################

# moment and loss structs
# Workspace holding the objects needed to compute model-implied (auto)covariance
# moments of an NKPH model and their parameter derivatives. Every array field is
# preallocated (zero-filled) by the outer constructor below and overwritten in place:
# moment objects by compute_nkphmmts_objects!, derivative objects by deriv_nkphmmts!
# and its helpers. Shapes listed are those allocated by that constructor, with
# J = nkphsolvr.J, N = nkphsolvr.N, N_params_all = nkphsolvr.N_params_all.
struct NKPHMoments
    # combine exog/endog parameter dictionary
    # for easy parameter lookup and avoiding repeated computation
    # (x_dict: values the moment objects were last computed at, see update_nkphmmts!;
    #  dx_dict: values the derivative objects were last computed at, see deriv_nkphmmts!)
    x_dict::OrderedDict{Symbol,Float64}
    dx_dict::OrderedDict{Symbol,Float64}

    # eigenvalue decomposition objects
    # eigen-decomposition of the solver's M (length J / J x J, complex for stable
    # typing) and the inverse of the Gamma eigenvector matrix taken from sdyn.W11
    M_eigvals::Vector{ComplexF64}
    M_eigvecs::Matrix{ComplexF64}
    M_eigvecs_inv::Matrix{ComplexF64}
    #Gamma_eigvals::Vector{ComplexF64}
    #Gamma_eigvecs::Matrix{ComplexF64}
    Gamma_eigvecs_inv::Matrix{ComplexF64}

    # schur decomp
    # Schur factors of Gamma (J x J) from schur(Gamma); passed to lyap for Sig_oo
    # and its derivatives
    Gamma_R::Matrix{Float64}
    Gamma_Q::Matrix{Float64}

    # for computing expm derivs
    # Phi matrices consumed by deriv_exp_matrix_tau: J x J x N_tau (one per maturity
    # in tau_arr) for M, and J x J x N_s_auto (one per lag in s_auto_diffs_arr) for Gamma
    int_Phi_tau_arr::Array{ComplexF64, 3}
    Phi_tau_arr::Array{ComplexF64, 3}
    Phi_s_arr::Array{ComplexF64, 3}

    # covariance sandwich matrices
    # long-run covariance (J x J): lyap(Gamma_R, Gamma_Q, -Sigma)
    Sig_oo::Matrix{Float64}
    # auto-covariance sandwich matrices
    # per lag s (J x J x N_s_auto): e_s_Gamma (identity at s = 0, otherwise from
    # calc_exp_matrix_tau on Gamma's eigen objects) and e_s_Gamma * Sig_oo
    e_s_Gamma_arr::Array{Float64, 3}
    e_s_GammaSig_oo_arr::Array{Float64, 3}

    # affine coeffs
    # per maturity (J x J x N_tau): outputs of calc_exp_matrix_tau_integral /
    # calc_exp_matrix_tau for M; A_tau = int_eM * ei_vec and
    # A_tau_til = A_tau - eM * ed_vec, stored column-wise (J x N_tau)
    int_eM_arr::Array{Float64, 3}
    eM_arr::Array{Float64, 3}
    A_tau_arr::Matrix{Float64}
    A_tau_arr_til::Matrix{Float64}

    # vector maturities and auto-cov time differences
    tau_arr::Vector{Float64}
    s_auto_diffs_arr::Vector{Float64}
    # dictionaries for mapping tau/s_auto to indices
    # (exact Float64 keys -> position in tau_arr / s_auto_diffs_arr)
    tau_dict::Dict{Float64, Int}
    s_auto_dict::Dict{Float64, Int}
    # specific short-rate/long-rate maturity index
    tau_short_idx::Int
    tau_long_idx::Int

    # lengths of tau_arr and s_auto_diffs_arr
    N_tau::Int
    N_s_auto::Int

    # derivative objects
    # last index is the derivative index (param_idx, or the column returned by
    # _AR_idx_to_deriv_idx / _m_idx_to_deriv_idx):
    # dSig_oo J x J x N_params_all; de_s_* J x J x N_s_auto x N_params_all;
    # dA_tau_* J x N_tau x N_params_all; dOmega (N-J) x J x N_params_all
    dSig_oo::Array{Float64, 3}
    de_s_Gamma_arr::Array{Float64, 4}
    de_s_GammaSig_oo_arr::Array{Float64, 4}
    dA_tau_arr::Array{Float64, 3}
    dA_tau_arr_til::Array{Float64, 3}
    dOmega::Array{Float64, 3}

    # pointer to underlying solver
    nkphsolvr::NKPHSolver
end
"""
    NKPHMoments(tau_short, tau_long, tau_arr, s_auto_diffs_arr, nkphsolvr)

Allocate a zero-filled `NKPHMoments` workspace for maturities `tau_arr` and
autocovariance lags `s_auto_diffs_arr`, bound to `nkphsolvr`. The parameter
dictionaries are deep-copied from the solver. `tau_short` and `tau_long` must be
elements of `tau_arr`; otherwise a `KeyError` is thrown.
"""
function NKPHMoments(
        tau_short::Float64,
        tau_long::Float64,
        tau_arr::Vector{Float64},
        s_auto_diffs_arr::Vector{Float64},
        nkphsolvr::NKPHSolver
    )
    J, N = nkphsolvr.J, nkphsolvr.N
    n_all = nkphsolvr.N_params_all
    n_tau = length(tau_arr)
    n_s = length(s_auto_diffs_arr)

    # value -> position lookups for maturities and autocovariance lags
    tau_lookup = Dict{Float64, Int}(t => k for (k, t) in enumerate(tau_arr))
    s_lookup = Dict{Float64, Int}(s => k for (k, s) in enumerate(s_auto_diffs_arr))
    short_pos = tau_lookup[tau_short]
    long_pos = tau_lookup[tau_long]

    # complex placeholders keep the eigen-related fields concretely typed
    cplx(dims...) = zeros(ComplexF64, dims...)

    return NKPHMoments(
        # parameter snapshots
        deepcopy(nkphsolvr.x_dict), deepcopy(nkphsolvr.dx_dict),
        # eigen objects: M values/vectors/inverse, Gamma inverse eigenvectors
        cplx(J), cplx(J, J), cplx(J, J), cplx(J, J),
        # Schur factors of Gamma
        zeros(J, J), zeros(J, J),
        # Phi matrices: per maturity (int, plain) and per lag
        cplx(J, J, n_tau), cplx(J, J, n_tau), cplx(J, J, n_s),
        # long-run covariance and autocovariance sandwiches
        zeros(J, J), zeros(J, J, n_s), zeros(J, J, n_s),
        # affine coefficient objects
        zeros(J, J, n_tau), zeros(J, J, n_tau), zeros(J, n_tau), zeros(J, n_tau),
        # maturity / lag bookkeeping
        tau_arr, s_auto_diffs_arr, tau_lookup, s_lookup, short_pos, long_pos,
        n_tau, n_s,
        # derivative objects (last dimension: exog/endog parameters)
        zeros(J, J, n_all), zeros(J, J, n_s, n_all), zeros(J, J, n_s, n_all),
        zeros(J, n_tau, n_all), zeros(J, n_tau, n_all),
        zeros(N - J, J, n_all),
        nkphsolvr
    )
end

"""
    compute_nkphmmts_objects!(nkphmmts::NKPHMoments) -> nothing

Recompute, in place, all non-derivative moment objects of `nkphmmts` from the current
state of its solver:

- eigen-decomposition of `M` and its inverse eigenvector matrix,
- inverse of the Gamma eigenvector matrix (`sdyn.W11`),
- Schur factors of `sdyn.Gamma` and the long-run covariance `Sig_oo`,
- per maturity `tau`: exponential-matrix terms of `M`, affine coefficients
  `A_tau` / `A_tau_til`, and the Phi matrices used for their derivatives,
- per lag `s`: `e_s_Gamma`, `e_s_Gamma * Sig_oo`, and the corresponding Phi matrix.
"""
function compute_nkphmmts_objects!(nkphmmts::NKPHMoments)::Nothing
    # pull out objects
    nkphsolvr = nkphmmts.nkphsolvr
    J = nkphsolvr.J
    M, Sigma = nkphsolvr.M, nkphsolvr.Sigma
    Gamma = nkphsolvr.sdyn.Gamma
    ei_vec, ed_vec = nkphsolvr.ei_vec, nkphsolvr.ed_vec
    M_eigvals, M_eigvecs = nkphmmts.M_eigvals, nkphmmts.M_eigvecs
    M_eigvecs_inv, Gamma_eigvecs_inv = nkphmmts.M_eigvecs_inv, nkphmmts.Gamma_eigvecs_inv
    Gamma_R, Gamma_Q = nkphmmts.Gamma_R, nkphmmts.Gamma_Q
    Sig_oo, e_s_Gamma_arr, e_s_GammaSig_oo_arr = nkphmmts.Sig_oo,
        nkphmmts.e_s_Gamma_arr, nkphmmts.e_s_GammaSig_oo_arr
    int_eM_arr, eM_arr = nkphmmts.int_eM_arr, nkphmmts.eM_arr
    A_tau_arr, A_tau_arr_til = nkphmmts.A_tau_arr, nkphmmts.A_tau_arr_til
    int_Phi_tau_arr, Phi_tau_arr = nkphmmts.int_Phi_tau_arr, nkphmmts.Phi_tau_arr
    Phi_s_arr = nkphmmts.Phi_s_arr

    # eigendecomposition of M (stored as complex for stable typing)
    M_eig = eigen(M)
    M_eigvals[:] = convert(Vector{ComplexF64}, M_eig.values)
    M_eigvecs[:, :] = convert(Matrix{ComplexF64}, M_eig.vectors)
    M_eigvecs_inv[:, :] = inv(M_eigvecs)
    # Gamma eigen objects are already available from the state decomposition
    Gamma_eigvals = nkphsolvr.sdyn.Lam1_diag
    Gamma_eigvecs = nkphsolvr.sdyn.W11
    Gamma_eigvecs_inv[:, :] = inv(Gamma_eigvecs)

    # Schur factors of Gamma, reused by lyap here and in the derivative code
    Gamma_R[:, :], Gamma_Q[:, :] = schur(Gamma)
    # long-run variance
    Sig_oo[:, :] = lyap(Gamma_R, Gamma_Q, -Sigma)

    # per-maturity exponential-matrix terms, affine coefficients and Phi matrices
    for (i, tau) in enumerate(nkphmmts.tau_arr)
        int_eM_tau = @view int_eM_arr[:, :, i]
        eM_tau = @view eM_arr[:, :, i]
        A_tau = @view A_tau_arr[:, i]
        A_tau_til = @view A_tau_arr_til[:, i]
        int_eM_tau[:, :] = calc_exp_matrix_tau_integral(tau, M_eigvals, M_eigvecs)
        eM_tau[:, :] = calc_exp_matrix_tau(tau, M_eigvals, M_eigvecs)
        # affine coeffs
        A_tau[:] = int_eM_tau * ei_vec
        A_tau_til[:] = A_tau - eM_tau * ed_vec

        # derivative Phi matrices (filled in place)
        int_Phi_tau = @view int_Phi_tau_arr[:, :, i]
        Phi_tau = @view Phi_tau_arr[:, :, i]
        calc_dexp_int_Phi_mat!(int_Phi_tau, tau, M_eigvals)
        calc_dexp_Phi_mat!(Phi_tau, tau, M_eigvals)
    end

    # autocovariance matrices, one slice per lag
    for (i, s_auto) in enumerate(nkphmmts.s_auto_diffs_arr)
        if iszero(s_auto)
            # zero lag: exp(0 * Gamma) = I, so the sandwich is Sig_oo itself and the
            # Phi matrix is zero. Keyed on the lag value rather than on its position
            # in s_auto_diffs_arr.
            e_s_GammaSig_oo_arr[:, :, i] = Sig_oo
            e_s_Gamma_arr[:, :, i] = Matrix{Float64}(I, J, J)
            Phi_s_arr[:, :, i] = zeros(ComplexF64, J, J)
        else
            e_s_Gamma = @view e_s_Gamma_arr[:, :, i]
            e_s_Gamma[:, :] = calc_exp_matrix_tau(s_auto, Gamma_eigvals, Gamma_eigvecs)
            e_s_GammaSig_oo_arr[:, :, i] = e_s_Gamma * Sig_oo
            # derivative Phi matrices (filled in place)
            Phi_s = @view Phi_s_arr[:, :, i]
            calc_dexp_Phi_mat!(Phi_s, s_auto, Gamma_eigvals)
        end
    end
    return nothing
end

# wrapper
"""
    update_nkphmmts!(x_arr, nkphmmts) -> nothing

Push the parameter vector `x_arr` into the underlying solver. Then recompute the moment
objects of `nkphmmts`, unless they were already computed at exactly these values.
"""
function update_nkphmmts!(x_arr::Vector{Float64}, nkphmmts::NKPHMoments)::Nothing
    update_nkphsolvr!(x_arr, nkphmmts.nkphsolvr)
    cached_vals = nkphmmts.x_dict.vals
    # cache hit: moment objects already correspond to x_arr
    x_arr == cached_vals && return nothing
    cached_vals[:] = x_arr
    compute_nkphmmts_objects!(nkphmmts)
    return nothing
end


#####################
# deriv functions
"""
    deriv_nkphmmts_objects_wrtM!(nkphmmts, j_idx, k_idx) -> nothing

Fill the derivative column of `dA_tau_arr` / `dA_tau_arr_til` that belongs to the entry
`m_{j_idx,k_idx}` of `M`, for every maturity. The column is located via
`_m_idx_to_deriv_idx`.
"""
function deriv_nkphmmts_objects_wrtM!(nkphmmts::NKPHMoments, j_idx::Int, k_idx::Int)::Nothing
    solver = nkphmmts.nkphsolvr
    V, V_inv = nkphmmts.M_eigvecs, nkphmmts.M_eigvecs_inv
    e_i, e_d = solver.ei_vec, solver.ed_vec
    col = _m_idx_to_deriv_idx(solver, j_idx, k_idx)
    for t in 1:nkphmmts.N_tau
        # derivatives of the integral and plain exponential-matrix terms
        dint_eM = deriv_exp_matrix_tau(view(nkphmmts.int_Phi_tau_arr, :, :, t),
            V, V_inv, j_idx, k_idx)
        deM = deriv_exp_matrix_tau(view(nkphmmts.Phi_tau_arr, :, :, t),
            V, V_inv, j_idx, k_idx)
        # dA = dint_eM * ei ; dA_til = dA - deM * ed
        dA_t = view(nkphmmts.dA_tau_arr, :, t, col)
        dA_til_t = view(nkphmmts.dA_tau_arr_til, :, t, col)
        dA_t[:] = dint_eM * e_i
        dA_til_t[:] = dA_t - deM * e_d
    end
    return nothing
end

# helper function
"""
    _deriv_autocov_mats!(nkphmmts, dGamma, dSigma, deriv_idx) -> nothing

Given the derivatives `dGamma` and `dSigma` for one parameter, write into derivative
column `deriv_idx`:

- `dSig_oo`: the `lyap` solution with right-hand side
  `-dSigma + dGamma*Sig_oo + (dGamma*Sig_oo)'`;
- for every lag: `de_s_Gamma` (via `deriv_exp_matrix_tau`) and
  `de_s_Gamma*Sig_oo + e_s_Gamma*dSig_oo`.
"""
function _deriv_autocov_mats!(nkphmmts::NKPHMoments,
        dGamma::AbstractMatrix, dSigma::AbstractMatrix,
        deriv_idx::Int)::Nothing
    Sig_oo = nkphmmts.Sig_oo
    W = nkphmmts.nkphsolvr.sdyn.W11   # Gamma eigenvectors from the state decomposition

    # long-run covariance derivative
    dG_S = dGamma * Sig_oo
    rhs = -dSigma + dG_S + dG_S'
    dSig_oo = view(nkphmmts.dSig_oo, :, :, deriv_idx)
    dSig_oo[:, :] = lyap(nkphmmts.Gamma_R, nkphmmts.Gamma_Q, rhs)

    # per-lag sandwich derivatives (product rule)
    for k in 1:nkphmmts.N_s_auto
        de_sG = view(nkphmmts.de_s_Gamma_arr, :, :, k, deriv_idx)
        de_sG_S = view(nkphmmts.de_s_GammaSig_oo_arr, :, :, k, deriv_idx)
        de_sG[:, :] = deriv_exp_matrix_tau(view(nkphmmts.Phi_s_arr, :, :, k), W, dGamma)
        de_sG_S[:, :] = de_sG * Sig_oo + view(nkphmmts.e_s_Gamma_arr, :, :, k) * dSig_oo
    end
    return nothing
end

"""
    deriv_nkphmmts_objects_wrtAR!(nkphmmts, j_idx) -> nothing

Compute the derivatives of `Omega`, `Gamma`, `Sig_oo` and the autocovariance
sandwiches with respect to `AR_hat_{j_idx}`. Results go into the derivative column
returned by `_AR_idx_to_deriv_idx`. `Sigma` does not depend on `AR_hat`, so a zero
`dSigma` is used.
"""
function deriv_nkphmmts_objects_wrtAR!(nkphmmts::NKPHMoments, j_idx)::Nothing
    solver = nkphmmts.nkphsolvr
    sdyn = solver.sdyn
    col = _AR_idx_to_deriv_idx(solver, j_idx)
    dUps = view(solver.dUpsilon, :, :, col)
    dGam = view(solver.dGamma, :, :, col)
    # jump (Omega) and state (Gamma) dynamics derivatives, both driven by dUpsilon
    deriv_jump_dynamics_wrtUpsilon!(view(nkphmmts.dOmega, :, :, col), sdyn, dUps)
    deriv_state_dynamics_wrtUpsilon!(dGam, sdyn, dUps)
    _deriv_autocov_mats!(nkphmmts, dGam, zero(dGam), col)
    return nothing
end

"""
    deriv_nkphmmts_objects_wrtparam!(nkphmmts, param_idx) -> nothing

Compute derivatives with respect to model parameter `param_idx`, stored in derivative
column `param_idx`:

- if the parameter enters `Upsilon`, update `dOmega` and the solver's `dGamma`;
- if it enters `Upsilon` or `sigma`, update `dSig_oo` and the autocovariance sandwich
  derivatives from the solver's `dGamma` / `dSigma`.
"""
function deriv_nkphmmts_objects_wrtparam!(nkphmmts::NKPHMoments, param_idx::Int)::Nothing
    solver = nkphmmts.nkphsolvr
    # flags saying which model objects this parameter enters
    pmap = solver.param_model_map[param_idx]
    in_Upsilon = pmap[:Upsilon]
    dGam = view(solver.dGamma, :, :, param_idx)
    if in_Upsilon
        dUps = view(solver.dUpsilon, :, :, param_idx)
        deriv_jump_dynamics_wrtUpsilon!(view(nkphmmts.dOmega, :, :, param_idx),
            solver.sdyn, dUps)
        deriv_state_dynamics_wrtUpsilon!(dGam, solver.sdyn, dUps)
    end
    if in_Upsilon || pmap[:sigma]
        _deriv_autocov_mats!(nkphmmts, dGam, view(solver.dSigma, :, :, param_idx),
            param_idx)
    end
    return nothing
end

# wrapper
"""
    deriv_nkphmmts!(nkphmmts) -> nothing

Update the solver derivatives, then, unless the derivative objects were already
computed at the current parameters, recompute all derivative objects of `nkphmmts`:
with respect to the model parameters, each `AR_hat_j`, and each entry of `M`.
"""
function deriv_nkphmmts!(nkphmmts::NKPHMoments)::Nothing
    solver = nkphmmts.nkphsolvr
    deriv_nkphsolvr!(solver)
    x_vals, dx_vals = nkphmmts.x_dict.vals, nkphmmts.dx_dict.vals
    # cache hit: derivatives already correspond to the current parameters
    x_vals == dx_vals && return nothing
    dx_vals[:] = x_vals
    J = solver.J
    foreach(p -> deriv_nkphmmts_objects_wrtparam!(nkphmmts, p), 1:solver.N_params)
    foreach(j -> deriv_nkphmmts_objects_wrtAR!(nkphmmts, j), 1:J)
    # M entries: j varies fastest, matching the column-major vec(M) ordering
    for (j_idx, k_idx) in Iterators.product(1:J, 1:J)
        deriv_nkphmmts_objects_wrtM!(nkphmmts, j_idx, k_idx)
    end
    return nothing
end



############################################################
# moment vars
# second moment variable structure
# A (possibly lagged/differenced) linear combination of the model state, described by
# its loading matrix V_mat and lag vector sV_auto_arr; filled by update_mmtvar_vecs!
# and differentiated by deriv_mmtvar!.
struct MomentVariable
    # moment variable:
    # sum_i V_i' * y_{t+s_i}
    # V_mat = [V_1 ... V_Ns]; JxNs matrix
    V_mat::Matrix{Float64}
    # sV_auto_arr = [s_1 ... s_Ns]; Ns vector
    # ([0.0] for level variables, [0.0, s] for difference variables)
    sV_auto_arr::Vector{Float64}

    # for mapping model into variable vectors
    # variable_name: full name as given (e.g. :D12_pi); variable_type: the name with
    # any "D<months>_" prefix stripped, used by _get_mmtvar_vec to pick the loading;
    # is_diff: whether such a prefix was found; tau: maturity used by the
    # maturity-specific types (tau > 0 is looked up in tau_dict)
    variable_name::Symbol
    variable_type::Symbol
    is_diff::Bool
    tau::Float64

    # derivative objects (final dimension N_params_all corresponding to exog/endog params)
    # shape J x Ns x N_params_all
    dV_mat::Array{Float64, 3}
end
"""
    _parse_variable_time_diffs(variable_name_str, s_auto_diffs_arr)

Look for a `"D<months>_"` marker in `variable_name_str`, where `<months>` is
`Int(12*s)` for each candidate lag `s` in `s_auto_diffs_arr`, checked in order. On
the first hit, return `(Symbol(rest_after_marker), true, [0.0, s])`. Otherwise return
`(Symbol(variable_name_str), false, [0.0])`.
"""
function _parse_variable_time_diffs(variable_name_str::String,
        s_auto_diffs_arr)::Tuple{Symbol, Bool, Vector{Float64}}
    for s_lag in s_auto_diffs_arr
        # time diffs are encoded in months
        marker = string("D", Int(12 * s_lag), "_")
        hit = findfirst(marker, variable_name_str)
        hit === nothing && continue
        # strip everything up to and including the marker
        base_name = variable_name_str[last(hit)+1:end]
        return Symbol(base_name), true, [0.0, s_lag]
    end
    # no marker: plain level variable observed at lag 0 only
    return Symbol(variable_name_str), false, [0.0]
end

# empty initialization with zero vecs
"""
    MomentVariable(variable_name, tau, nkphmmts)

Build a `MomentVariable` with zero loadings. Its type, difference flag and lags are
parsed from `variable_name` against `nkphmmts.s_auto_diffs_arr`.
"""
function MomentVariable(variable_name::Symbol, tau::Float64,
        nkphmmts::NKPHMoments)
    solver = nkphmmts.nkphsolvr
    vtype, is_diff, s_lags = _parse_variable_time_diffs(String(variable_name),
        nkphmmts.s_auto_diffs_arr)
    # loadings are J x (number of lags); derivatives add a parameter dimension
    dims = (solver.J, length(s_lags))
    return MomentVariable(zeros(dims), s_lags,
        variable_name, vtype, is_diff, tau,
        zeros(dims..., solver.N_params_all)
    )
end

"""
    _get_mmtvar_vec(variable_type, tau, tau_idx, tau_short, tau_short_idx,
                    tau_long, tau_long_idx, A_tau_arr, A_tau_arr_til, Omega,
                    ei_vec, ed_vec) -> Vector{Float64}

Return the loading vector for `variable_type`, built from the supplied coefficient
arrays. The same function serves for values and derivatives: pass the derivative
arrays and zero vectors.

- `:pi` / `:x`: rows 1 / 2 of `Omega`; `:b_shock`: ones.
- `:ishort*` / `:ilong*` / `:y*`: built from column `tau_short_idx` / `tau_long_idx` /
  `tau_idx`, divided by the matching maturity, as the level (`A`),
  the "til" form (`A_til + ed_vec`), or the difference (`A_til + ed_vec - A`).
- `:y_slope` / `:ytil_slope`: yield at `tau` minus the short yield.

Throws `DomainError` for any other type.
"""
function _get_mmtvar_vec(
        variable_type::Symbol,
        tau::Float64,
        tau_idx::Int,
        tau_short::Float64,
        tau_short_idx::Int,
        tau_long::Float64,
        tau_long_idx::Int,
        A_tau_arr::AbstractMatrix,
        A_tau_arr_til::AbstractMatrix,
        Omega::AbstractMatrix,
        ei_vec::Vector{Float64},
        ed_vec::Vector{Float64},
        #b_shock_val::Float64
    )::Vector{Float64}
    # non-yield variables
    variable_type == :pi && return Omega[1, :]
    variable_type == :x && return Omega[2, :]
    variable_type == :b_shock && return ones(length(ei_vec))

    # yield-type variables: choose the maturity column and length first
    if variable_type in (:ishort, :ishorttil, :ishort_diff)
        col, mat = tau_short_idx, tau_short
    elseif variable_type in (:ilong, :ilongtil, :ilong_diff)
        col, mat = tau_long_idx, tau_long
    elseif variable_type in (:y, :ytil, :y_diff, :y_slope, :ytil_slope)
        col, mat = tau_idx, tau
    else
        # note: not all variable types supported
        throw(DomainError(variable_type, "unsupported variable_type type"))
    end

    # ... then the construction
    if variable_type in (:ishort, :ilong, :y)
        return A_tau_arr[:, col] ./ mat
    elseif variable_type in (:ishorttil, :ilongtil, :ytil)
        return (A_tau_arr_til[:, col] + ed_vec) ./ mat
    elseif variable_type in (:ishort_diff, :ilong_diff, :y_diff)
        return (A_tau_arr_til[:, col] + ed_vec - A_tau_arr[:, col]) ./ mat
    elseif variable_type == :y_slope
        return (A_tau_arr[:, col] ./ mat) - (A_tau_arr[:, tau_short_idx]) ./ tau_short
    else
        # :ytil_slope
        # NOTE(review): unlike :ytil there is no `+ ed_vec` term here; confirm intended
        return (A_tau_arr_til[:, col] ./ mat) - (A_tau_arr[:, tau_short_idx]) ./ tau_short
    end
end


"""
    update_mmtvar_vecs!(mmtvar, nkphmmts) -> nothing

Refresh `mmtvar.V_mat` from the current affine coefficients and `Omega`. A level
variable loads `v` at lag 0. A difference variable loads `-v` at lag 0 and `+v` at
its lag `s`.
"""
function update_mmtvar_vecs!(mmtvar::MomentVariable, nkphmmts::NKPHMoments)::Nothing
    solver = nkphmmts.nkphsolvr
    i_short, i_long = nkphmmts.tau_short_idx, nkphmmts.tau_long_idx
    tau = mmtvar.tau
    # index 0 is passed through for tau <= 0; only valid for types not using tau
    i_tau = tau > 0 ? nkphmmts.tau_dict[tau] : 0
    v = _get_mmtvar_vec(mmtvar.variable_type, tau, i_tau,
        nkphmmts.tau_arr[i_short], i_short, nkphmmts.tau_arr[i_long], i_long,
        nkphmmts.A_tau_arr, nkphmmts.A_tau_arr_til, solver.sdyn.Omega,
        solver.ei_vec, solver.ed_vec
    )
    if mmtvar.is_diff
        mmtvar.V_mat[:, 1] = -v
        mmtvar.V_mat[:, 2] = v
    else
        mmtvar.V_mat[:, 1] = v
    end
    return nothing
end


"""
    deriv_mmtvar!(mmtvar::MomentVariable, nkphmmts::NKPHMoments) -> nothing

Fill `mmtvar.dV_mat[:, :, idx]` for every derivative index `idx in 1:N_params_all`
with the derivative of the loading matrix `V_mat`. The derivatives are built from the
arrays `dA_tau_arr`, `dA_tau_arr_til` and `dOmega` of `nkphmmts`, which must already
be up to date (e.g. via `deriv_nkphmmts!`).

The `:b_shock` loading is a constant vector of ones, so its derivative is zero.
"""
function deriv_mmtvar!(mmtvar::MomentVariable, nkphmmts::NKPHMoments)::Nothing
    # pull out objects
    nkphsolvr = nkphmmts.nkphsolvr
    J, N_params_all = nkphsolvr.J, nkphsolvr.N_params_all
    tau_short_idx = nkphmmts.tau_short_idx
    tau_short = nkphmmts.tau_arr[tau_short_idx]
    tau_long_idx = nkphmmts.tau_long_idx
    tau_long = nkphmmts.tau_arr[tau_long_idx]

    tau, is_diff = mmtvar.tau, mmtvar.is_diff
    # tau_idx is only meaningful for maturity-specific variable types
    tau_idx = tau > 0 ? nkphmmts.tau_dict[tau] : 0
    variable_type = mmtvar.variable_type
    # ei_vec/ed_vec do not depend on parameters, so their derivatives are zero vectors.
    # Allocate once; _get_mmtvar_vec only reads them, and dV_mat assignments copy.
    zero_vec = zeros(J)

    for idx in 1:N_params_all
        if variable_type == :b_shock
            # constant loading: _get_mmtvar_vec would return ones here regardless of
            # the derivative arrays, which is not its derivative
            _dvec = zero_vec
        else
            # A coeff / Omega derivative objects for this derivative index
            dA_tau_arr = @view nkphmmts.dA_tau_arr[:, :, idx]
            dA_tau_arr_til = @view nkphmmts.dA_tau_arr_til[:, :, idx]
            dOmega = @view nkphmmts.dOmega[:, :, idx]
            _dvec = _get_mmtvar_vec(variable_type, tau, tau_idx,
                tau_short, tau_short_idx, tau_long, tau_long_idx,
                dA_tau_arr, dA_tau_arr_til, dOmega, zero_vec, zero_vec
            )
        end
        # update dV matrix (difference vars load -v at lag 0 and +v at lag s)
        if is_diff
            mmtvar.dV_mat[:, 1, idx] = -_dvec
            mmtvar.dV_mat[:, 2, idx] = _dvec
        else
            mmtvar.dV_mat[:, 1, idx] = _dvec
        end
    end
    return nothing
end


####################################################################################
# generic methods for computing second moments

# unconditional second moments
"""
    _calc_cov_moment(V_matL, V_matR, sV_auto_arrL, sV_auto_arrR, eS_arr, s_auto_dict)

Return the covariance between the moment variables `sum_i V_matL[:, i]' * y_{t+sL_i}`
and `sum_j V_matR[:, j]' * y_{t+sR_j}`, computed as
`sum_{i,j} V_matL[:, i]' * S_ij * V_matR[:, j]`.

With `d = sL_i - sR_j`, the sandwich `S_ij` is `eS_arr[:, :, s_auto_dict[d]]` when
`d >= 0` and the transpose of `eS_arr[:, :, s_auto_dict[-d]]` when `d < 0`.

Throws a `KeyError` if some `|sL_i - sR_j|` is not a key of `s_auto_dict`.
"""
function _calc_cov_moment(
        V_matL::AbstractMatrix,
        V_matR::AbstractMatrix,
        sV_auto_arrL::Vector{Float64},
        sV_auto_arrR::Vector{Float64},
        eS_arr::AbstractArray,
        s_auto_dict::Dict{Float64, Int}
    )::Float64
    val = 0.0
    for (idx_L, s_autoL) in enumerate(sV_auto_arrL)
        # views avoid copying a column / a J x J slice on every (L, R) pair
        vL = view(V_matL, :, idx_L)
        for (idx_R, s_autoR) in enumerate(sV_auto_arrR)
            vR = view(V_matR, :, idx_R)
            s_auto_diff = s_autoL - s_autoR
            # abs() also maps -0.0 to the 0.0 key
            S = view(eS_arr, :, :, s_auto_dict[abs(s_auto_diff)])
            if s_auto_diff < 0
                val += vL' * S' * vR
            else
                val += vL' * S * vR
            end
        end
    end
    return val
end

# wrapper
"""
    calc_cov_moment(mmtvarL, mmtvarR, nkphmmts) -> Float64

Model-implied covariance between two moment variables, using the autocovariance
sandwiches `e_s_GammaSig_oo_arr` of `nkphmmts`.
"""
function calc_cov_moment(mmtvarL::MomentVariable, mmtvarR::MomentVariable,
        nkphmmts::NKPHMoments)::Float64
    return _calc_cov_moment(
        mmtvarL.V_mat, mmtvarR.V_mat,
        mmtvarL.sV_auto_arr, mmtvarR.sV_auto_arr,
        nkphmmts.e_s_GammaSig_oo_arr, nkphmmts.s_auto_dict
    )
end

"""
    deriv_cov_moment!(dval_arr, mmtvarL, mmtvarR, nkphmmts) -> nothing

Write into `dval_arr[p]`, for every `p in 1:N_params_all`, the derivative of
`calc_cov_moment(mmtvarL, mmtvarR, nkphmmts)`. It is the sum of three terms: the
derivative through the sandwiches, through the left loadings, and through the right
loadings.
"""
function deriv_cov_moment!(dval_arr::Vector{Float64},
        mmtvarL::MomentVariable, mmtvarR::MomentVariable,
        nkphmmts::NKPHMoments)::Nothing
    VL, VR = mmtvarL.V_mat, mmtvarR.V_mat
    sL, sR = mmtvarL.sV_auto_arr, mmtvarR.sV_auto_arr
    lag_index = nkphmmts.s_auto_dict
    S_arr = nkphmmts.e_s_GammaSig_oo_arr
    for p in 1:nkphmmts.nkphsolvr.N_params_all
        dS_arr = view(nkphmmts.de_s_GammaSig_oo_arr, :, :, :, p)
        dVL = view(mmtvarL.dV_mat, :, :, p)
        dVR = view(mmtvarR.dV_mat, :, :, p)
        term_S = _calc_cov_moment(VL, VR, sL, sR, dS_arr, lag_index)
        term_L = _calc_cov_moment(dVL, VR, sL, sR, S_arr, lag_index)
        term_R = _calc_cov_moment(VL, dVR, sL, sR, S_arr, lag_index)
        dval_arr[p] = term_S + term_L + term_R
    end
    return nothing
end


# conditional second moments
"""
    calc_conditional_coeff_moment(mmtvarL, mmtvarR, idx_condcoeff, nkphmmts) -> Float64

'Conditional' regression coefficient: the ratio of the `idx_condcoeff`-th entries of
the first (lag-0) loading columns of the two variables. It is only meaningful for
level ('vector') moment variables; `nkphmmts` is unused.
"""
function calc_conditional_coeff_moment(mmtvarL::MomentVariable, mmtvarR::MomentVariable,
        idx_condcoeff::Int,
        nkphmmts::NKPHMoments)::Float64
    numer = mmtvarL.V_mat[idx_condcoeff, 1]
    denom = mmtvarR.V_mat[idx_condcoeff, 1]
    return numer / denom
end

"""
    deriv_conditional_coeff_moment!(dval_arr, mmtvarL, mmtvarR, idx_condcoeff, nkphmmts)

Write into `dval_arr[p]`, for every `p in 1:N_params_all`, the derivative of the
'conditional' coefficient `a / b`, using the quotient rule. Here `a` and `b` are the
`idx_condcoeff`-th first-column loadings of `mmtvarL` and `mmtvarR`.
"""
function deriv_conditional_coeff_moment!(dval_arr::Vector{Float64},
        mmtvarL::MomentVariable, mmtvarR::MomentVariable,
        idx_condcoeff::Int,
        nkphmmts::NKPHMoments)::Nothing
    a = mmtvarL.V_mat[idx_condcoeff, 1]
    b = mmtvarR.V_mat[idx_condcoeff, 1]
    b_sq = b^2
    for p in 1:nkphmmts.nkphsolvr.N_params_all
        da = mmtvarL.dV_mat[idx_condcoeff, 1, p]
        db = mmtvarR.dV_mat[idx_condcoeff, 1, p]
        dval_arr[p] = (da * b - a * db) / b_sq
    end
    return nothing
end



####################################
# moment loss object
# Squared-distance loss for one model moment b against its target bhat:
# loss = wgt * (b - bhat)^2 (see set_loss_val!). How b is computed from the two moment
# variables is selected by loss_type (:cov, :sd, :corr, :coeff, :cond_coeff).
struct MomentLoss
    # target and model, and loss val
    loss_type::Symbol
    # loss weight
    wgt::Float64
    # store in 1-dim mutable arrays
    # (length-1 vectors so the immutable struct can still be updated in place)
    # bhat: target; loss: wgt * b_diff^2; b: model moment; b_diff: b - bhat
    bhat::Vector{Float64}
    loss::Vector{Float64}
    b::Vector{Float64}
    b_diff::Vector{Float64}
    # cov(L,R), var(L), var(R); for corr/coeff type moment objects
    covLR::Vector{Float64}
    varL::Vector{Float64}
    varR::Vector{Float64}

    # for 'conditional' coeffs
    # row of V_mat whose first-column entries are divided for :cond_coeff
    idx_condcoeff::Int

    # derivatives objects
    # each of length size(mmtvarL.dV_mat, 3) (N_params_all): derivatives of
    # loss, b, covLR, varL and varR respectively
    dloss::Vector{Float64}
    db::Vector{Float64}
    dcovLR::Vector{Float64}
    dvL::Vector{Float64}
    dvR::Vector{Float64}

    # auxiliary values for computing derivs (loss_type-specific)
    #_aux_vals::Vector{Float64}

    # pointer to moment variables
    # (only mmtvarL is read for :sd)
    mmtvarL::MomentVariable
    mmtvarR::MomentVariable
end
"""
    MomentLoss(loss_type, wgt, bhat, idx_condcoeff, mmtvarL, mmtvarR)

Create a `MomentLoss` with target `bhat`. The scalar slots are initialized
(`loss = 1.0`, everything else `0.0`), and the derivative buffers are zero vectors
sized by the parameter dimension of `mmtvarL.dV_mat`.
"""
function MomentLoss(loss_type::Symbol, wgt::Float64, bhat::Float64, idx_condcoeff::Int,
        mmtvarL::MomentVariable, mmtvarR::MomentVariable)
    n_all = size(mmtvarL.dV_mat, 3)
    return MomentLoss(loss_type, wgt,
        [bhat], [1.0], [0.0], [0.0],          # bhat, loss, b, b_diff
        [0.0], [0.0], [0.0],                  # covLR, varL, varR
        idx_condcoeff,
        zeros(n_all), zeros(n_all),           # dloss, db
        zeros(n_all), zeros(n_all), zeros(n_all),  # dcovLR, dvL, dvR
        mmtvarL, mmtvarR
    )
end

##############
# wrappers
"""
    calc_cov_moment!(mmtloss, nkphmmts) -> nothing

Set `b` to the covariance of the two moment variables. `covLR`, `varL` and `varR` are
not touched.
"""
function calc_cov_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    cov_val = calc_cov_moment(mmtloss.mmtvarL, mmtloss.mmtvarR, nkphmmts)
    mmtloss.b[1] = cov_val
    return nothing
end

"""
    calc_sd_moment!(mmtloss, nkphmmts) -> nothing

Set `varL` to the variance of `mmtvarL` and `b` to `sqrt(abs(varL))`. `abs` avoids
domain errors; per the original note, bad Gamma eigenvalues are caught elsewhere.
"""
function calc_sd_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    v = mmtloss.mmtvarL
    variance = calc_cov_moment(v, v, nkphmmts)
    mmtloss.varL[1] = variance
    mmtloss.b[1] = sqrt(abs(variance))
    return nothing
end

"""
    calc_corr_moment!(mmtloss, nkphmmts) -> nothing

Set `b` to the correlation `covLR / (sqrt|varL| * sqrt|varR|)`. `covLR`, `varL` and
`varR` are stored for the derivative step.
"""
function calc_corr_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    L, R = mmtloss.mmtvarL, mmtloss.mmtvarR
    c = calc_cov_moment(L, R, nkphmmts)
    vL = calc_cov_moment(L, L, nkphmmts)
    vR = calc_cov_moment(R, R, nkphmmts)
    mmtloss.covLR[1], mmtloss.varL[1], mmtloss.varR[1] = c, vL, vR
    mmtloss.b[1] = c / (sqrt(abs(vL)) * sqrt(abs(vR)))
    return nothing
end

# "unconditional" regression coeff
"""
    calc_coeff_moment!(mmtloss, nkphmmts) -> nothing

'Unconditional' regression coefficient of L on R: `b = covLR / varR`. `covLR` and
`varR` are stored for the derivative step.
"""
function calc_coeff_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    L, R = mmtloss.mmtvarL, mmtloss.mmtvarR
    c = calc_cov_moment(L, R, nkphmmts)
    vR = calc_cov_moment(R, R, nkphmmts)
    mmtloss.covLR[1] = c
    mmtloss.varR[1] = vR
    mmtloss.b[1] = c / vR
    return nothing
end

# "conditional" regression coeff
"""
    calc_conditional_coeff_moment!(mmtloss, nkphmmts) -> nothing

Set `b` to the 'conditional' regression coefficient at row `idx_condcoeff`.
`covLR`, `varL` and `varR` are not needed and not touched.
"""
function calc_conditional_coeff_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    mmtloss.b[1] = calc_conditional_coeff_moment(
        mmtloss.mmtvarL, mmtloss.mmtvarR, mmtloss.idx_condcoeff, nkphmmts)
    return nothing
end



##############
# derivs
"""
    deriv_cov_moment!(mmtloss, nkphmmts) -> nothing

The moment is the covariance itself, so its derivative is written straight into `db`.
"""
function deriv_cov_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    L, R = mmtloss.mmtvarL, mmtloss.mmtvarR
    return deriv_cov_moment!(mmtloss.db, L, R, nkphmmts)
end

"""
    deriv_sd_moment!(mmtloss, nkphmmts) -> nothing

Derivative of `b = sqrt(abs(varL))` as set by `calc_sd_moment!`, which must have run
first (it provides `varL` and `b`). Writes `dvL` and `db`.

Since `d sqrt|v| = sign(v) * dv / (2 sqrt|v|)`, the sign factor keeps the gradient
consistent with the value when `varL < 0`. It equals 1 for positive variances, in
which case the result is unchanged.
"""
function deriv_sd_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    mmtvar = mmtloss.mmtvarL
    deriv_cov_moment!(mmtloss.dvL, mmtvar, mmtvar, nkphmmts)
    mmtloss.db[:] = sign(mmtloss.varL[1]) .* mmtloss.dvL ./ (2 * mmtloss.b[1])
    return nothing
end

"""
    deriv_corr_moment!(mmtloss, nkphmmts) -> nothing

Derivative of the correlation `b = covLR / (sdL * sdR)` with `sd = sqrt(abs(var))`,
as set by `calc_corr_moment!`, which must have run first (it provides `covLR`, `varL`
and `varR`). Writes `dcovLR`, `dvL`, `dvR` and `db`.

`d sqrt|v| = sign(v) * dv / (2 sqrt|v|)`, so each sd derivative carries `sign(var)`.
This matches the `abs` used for the value; for positive variances it is a no-op.
"""
function deriv_corr_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    mmtvarL, mmtvarR = mmtloss.mmtvarL, mmtloss.mmtvarR
    # compute cov/var derivs
    deriv_cov_moment!(mmtloss.dcovLR, mmtvarL, mmtvarR, nkphmmts)
    deriv_cov_moment!(mmtloss.dvL, mmtvarL, mmtvarL, nkphmmts)
    deriv_cov_moment!(mmtloss.dvR, mmtvarR, mmtvarR, nkphmmts)
    # sd derivs (sign-aware to match sqrt(abs(var)))
    varL, varR = mmtloss.varL[1], mmtloss.varR[1]
    val_denomL = sqrt(abs(varL))
    val_denomR = sqrt(abs(varR))
    dsdL = sign(varL) .* mmtloss.dvL ./ (2 * val_denomL)
    dsdR = sign(varR) .* mmtloss.dvR ./ (2 * val_denomR)
    # quotient rule for covLR / (sdL * sdR)
    mmtloss.db[:] = mmtloss.dcovLR ./ (val_denomL * val_denomR)
    mmtloss.db[:] -= dsdL .* (mmtloss.covLR[1] / (val_denomL^2 * val_denomR))
    mmtloss.db[:] -= dsdR .* (mmtloss.covLR[1] / (val_denomL * val_denomR^2))
    return nothing
end

"""
    deriv_coeff_moment!(mmtloss, nkphmmts) -> nothing

Derivative of `b = covLR / varR`, using the quotient rule. It uses the `covLR` and
`varR` stored by `calc_coeff_moment!`, and writes `dcovLR`, `dvR` and `db`.
"""
function deriv_coeff_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    L, R = mmtloss.mmtvarL, mmtloss.mmtvarR
    deriv_cov_moment!(mmtloss.dcovLR, L, R, nkphmmts)
    deriv_cov_moment!(mmtloss.dvR, R, R, nkphmmts)
    c, vR = mmtloss.covLR[1], mmtloss.varR[1]
    mmtloss.db[:] = mmtloss.dcovLR ./ vR .- mmtloss.dvR .* (c / (vR^2))
    return nothing
end

"""
    deriv_conditional_coeff_moment!(mmtloss, nkphmmts) -> nothing

Write the derivative of the 'conditional' coefficient into `db`.
"""
function deriv_conditional_coeff_moment!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    return deriv_conditional_coeff_moment!(mmtloss.db, mmtloss.mmtvarL, mmtloss.mmtvarR,
        mmtloss.idx_condcoeff, nkphmmts)
end


"""
    set_loss_val!(mmtloss, nkphmmts) -> nothing

Compute the model moment `b` according to `loss_type`. Then set
`b_diff = b - bhat` and `loss = wgt * b_diff^2`. Throws `DomainError` for an unknown
`loss_type`.
"""
function set_loss_val!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    lt = mmtloss.loss_type
    calc! = if lt == :cov
        calc_cov_moment!
    elseif lt == :sd
        calc_sd_moment!
    elseif lt == :corr
        calc_corr_moment!
    elseif lt == :coeff
        calc_coeff_moment!
    elseif lt == :cond_coeff
        calc_conditional_coeff_moment!
    else
        throw(DomainError(lt, "unsupported loss_type"))
    end
    calc!(mmtloss, nkphmmts)
    gap = mmtloss.b[1] - mmtloss.bhat[1]
    mmtloss.b_diff[1] = gap
    mmtloss.loss[1] = mmtloss.wgt * gap^2
    return nothing
end

"""
    deriv_loss_val!(mmtloss, nkphmmts) -> nothing

Compute `db` according to `loss_type`, then `dloss = 2 * wgt * b_diff * db`. It
relies on `b_diff` (and the moment-specific values) set by `set_loss_val!`. Throws
`DomainError` for an unknown `loss_type`.
"""
function deriv_loss_val!(mmtloss::MomentLoss, nkphmmts::NKPHMoments)::Nothing
    lt = mmtloss.loss_type
    deriv! = if lt == :cov
        deriv_cov_moment!
    elseif lt == :sd
        deriv_sd_moment!
    elseif lt == :corr
        deriv_corr_moment!
    elseif lt == :coeff
        deriv_coeff_moment!
    elseif lt == :cond_coeff
        deriv_conditional_coeff_moment!
    else
        throw(DomainError(lt, "unsupported loss_type"))
    end
    deriv!(mmtloss, nkphmmts)
    scale = 2 * mmtloss.wgt
    mmtloss.dloss[:] = scale .* mmtloss.b_diff .* mmtloss.db
    return nothing
end

# broadcasting for loss arrs
"""
    calc_loss(mmtloss_arr) -> Float64

Total loss: the left-to-right sum of each element's stored `loss[1]` (0.0 when empty).
"""
function calc_loss(mmtloss_arr::Vector{MomentLoss})::Float64
    return foldl((acc, m) -> acc + m.loss[1], mmtloss_arr; init=0.0)
end

"""
    deriv_loss!(dloss_arr, mmtloss_arr) -> nothing

Add each loss's `dloss` gradient into `dloss_arr` in place. `dloss_arr` is
accumulated into, not reset, so callers must zero it first if needed. Mismatched
lengths throw `DimensionMismatch`.
"""
function deriv_loss!(dloss_arr::Vector{Float64},
        mmtloss_arr::Vector{MomentLoss})::Nothing
    for mmtloss in mmtloss_arr
        grad = mmtloss.dloss
        for k in eachindex(dloss_arr, grad)
            dloss_arr[k] += grad[k]
        end
    end
    return nothing
end



########################################################################
# combined target struct
# Estimation targets built on top of an NKPHMoments object: the moment variables,
# the per-moment loss objects, the total loss, the eigenvalue-constraint values,
# and the gradients of all of these with respect to the full parameter vector.
struct NKPHTargets
    # combine exog/endog parameter dictionary
    # for easy parameter lookup and avoiding repeated computation
    # x_dict.vals: parameter vector at which the target values were last computed
    #   (compared against the incoming vector in update_nkphtgts!)
    # dx_dict.vals: parameter vector at which the target derivatives were last
    #   computed (compared against x_dict.vals in deriv_nkphtgts!)
    x_dict::OrderedDict{Symbol,Float64}
    dx_dict::OrderedDict{Symbol,Float64}

    # moment variables
    mmtvars_arr::Vector{MomentVariable}
    # loss objects (based on mmtvars)
    mmtloss_arr::Vector{MomentLoss}

    # for continuation targets (one entry per loss object):
    # bhat_all0_arr holds the model moments b at the starting point,
    # bhat_all1_arr holds the final target moments bhat;
    # targets are interpolated between them by update_continuation_targets!
    bhat_all0_arr::Vector{Float64}
    bhat_all1_arr::Vector{Float64}

    # moment/solver object
    nkphmmts::NKPHMoments

    # loss value, and eigenvalue constraint (store in 1-dim mutable arrays)
    # loss_val[1]: sum of all moment losses
    # Upsilon_min_pos_eigval[1] / Upsilon_max_neg_eigval[1]: real parts of the
    #   Upsilon eigenvalues at positions J and J+1 of sdyn.Lam_diag
    # M_min_eigval[1]: real part of the first stored eigenvalue of M
    loss_val::Vector{Float64}
    Upsilon_min_pos_eigval::Vector{Float64}
    Upsilon_max_neg_eigval::Vector{Float64}
    M_min_eigval::Vector{Float64}

    # for derivatives: N_params_all vectors
    # (gradients of the four scalar quantities above)
    dloss_val::Vector{Float64}
    dUpsilon_min_pos_eigval::Vector{Float64}
    dUpsilon_max_neg_eigval::Vector{Float64}
    dM_min_eigval::Vector{Float64}
end
# outer constructor: build an NKPHTargets from the given mmtvar/mmtloss arrays and moment object
function NKPHTargets(
        mmtvars_arr::Vector{MomentVariable},
        mmtloss_arr::Vector{MomentLoss},
        nkphmmts::NKPHMoments
    )
    # Build an NKPHTargets wrapper around `nkphmmts` with zeroed loss/eigenvalue
    # buffers, zeroed gradient vectors of length N_params_all, and continuation-target
    # arrays sized to `mmtloss_arr`.
    #
    # x_dict/dx_dict reuse the solver's parameter keys but act as this object's own
    # caches: they record the parameter vector at which the target values / target
    # derivatives were last computed. Copying the solver's *values* would wrongly
    # mark the targets as already computed whenever the solver has been evaluated
    # at the same point (update_nkphtgts!/deriv_nkphtgts! would then skip their
    # work). The values therefore start as NaN, which never compares equal, so the
    # first update and the first derivative call always compute.
    nkphsolvr = nkphmmts.nkphsolvr
    N_params_all = nkphsolvr.N_params_all
    x_dict = deepcopy(nkphsolvr.x_dict)
    dx_dict = deepcopy(nkphsolvr.dx_dict)
    x_dict.vals .= NaN
    dx_dict.vals .= NaN

    # continuation targets: one entry per loss object
    N_loss = length(mmtloss_arr)
    bhat_all0_arr = zeros(N_loss)
    bhat_all1_arr = zeros(N_loss)

    # scalar outputs stored in 1-element arrays so the immutable struct can update them
    loss_val = [0.0]
    Upsilon_min_pos_eigval = [0.0]
    Upsilon_max_neg_eigval = [0.0]
    M_min_eigval = [0.0]

    # gradients with respect to all parameters
    dloss_val = zeros(N_params_all)
    dUpsilon_min_pos_eigval = zeros(N_params_all)
    dUpsilon_max_neg_eigval = zeros(N_params_all)
    dM_min_eigval = zeros(N_params_all)

    return NKPHTargets(x_dict, dx_dict,
        mmtvars_arr, mmtloss_arr, bhat_all0_arr, bhat_all1_arr,
        nkphmmts,
        loss_val, Upsilon_min_pos_eigval, Upsilon_max_neg_eigval, M_min_eigval,
        dloss_val, dUpsilon_min_pos_eigval, dUpsilon_max_neg_eigval, dM_min_eigval
    )
end


function compute_nkphtgts_objects!(nkphtgts::NKPHTargets)::Nothing
    # Recompute every target quantity from the current state of the moment/solver object:
    #   1. refresh each moment variable,
    #   2. evaluate each moment loss (b, b_diff, loss),
    #   3. store the total loss in loss_val[1],
    #   4. store the eigenvalue-constraint values (real parts) for Upsilon and M.
    nkphmmts = nkphtgts.nkphmmts
    foreach(mmtvar -> update_mmtvar_vecs!(mmtvar, nkphmmts), nkphtgts.mmtvars_arr)
    foreach(mmtloss -> set_loss_val!(mmtloss, nkphmmts), nkphtgts.mmtloss_arr)
    nkphtgts.loss_val[1] = calc_loss(nkphtgts.mmtloss_arr)

    nkphsolvr = nkphmmts.nkphsolvr
    J = nkphsolvr.J
    # The code relies on sdyn.Lam_diag being ordered by decreasing real part, so that
    # entry J is the smallest positive eigenvalue of Upsilon and entry J+1 the largest
    # negative one (NOTE(review): ordering is established where sdyn is built).
    Lam = nkphsolvr.sdyn.Lam_diag
    nkphtgts.Upsilon_min_pos_eigval[1] = real(Lam[J])
    nkphtgts.Upsilon_max_neg_eigval[1] = real(Lam[J+1])

    # The first stored eigenvalue of M is taken as the minimum, which assumes
    # M_eigvals is sorted ascending by real part (TODO confirm where it is computed).
    nkphtgts.M_min_eigval[1] = real(nkphmmts.M_eigvals[1])
    return nothing
end

#####
# derivs
# Write the derivatives (real parts) of the two boundary Upsilon eigenvalues, the
# smallest positive and the largest negative, for the parameter at `deriv_idx`,
# given that parameter's Upsilon derivative matrix `dUpsilon`.
function _set_upsilon_boundary_eigval_derivs!(nkphtgts::NKPHTargets, deriv_idx::Integer,
        dUpsilon, umin, vmin, umax, vmax)::Nothing
    nkphtgts.dUpsilon_min_pos_eigval[deriv_idx] = real(
        deriv_simple_eigval(umin, vmin, dUpsilon)
    )
    nkphtgts.dUpsilon_max_neg_eigval[deriv_idx] = real(
        deriv_simple_eigval(umax, vmax, dUpsilon)
    )
    return nothing
end

function deriv_min_eigvals!(nkphtgts::NKPHTargets)::Nothing
    # Fill the gradients of the eigenvalue constraints: dUpsilon_min_pos_eigval,
    # dUpsilon_max_neg_eigval and dM_min_eigval.
    # For each eigenvalue, the right eigenvector u is a column of the eigenvector
    # matrix U, and v is the matching column of inv(U)'. Both are passed to
    # deriv_simple_eigval (presumably the simple-eigenvalue derivative formula;
    # see its definition). Gradient entries for parameters that do not enter the
    # corresponding matrix are left untouched.
    nkphmmts = nkphtgts.nkphmmts
    nkphsolvr = nkphmmts.nkphsolvr
    N_params, J = nkphsolvr.N_params, nkphsolvr.J

    # --- Upsilon: eigenvalues J (smallest positive) and J+1 (largest negative) ---
    U0_Upsilon = nkphsolvr.sdyn.W
    V0_Upsilon = inv(U0_Upsilon)'
    umin_Upsilon, vmin_Upsilon = U0_Upsilon[:, J], V0_Upsilon[:, J]
    umax_Upsilon, vmax_Upsilon = U0_Upsilon[:, J+1], V0_Upsilon[:, J+1]
    # structural parameters: only those the model map flags as entering Upsilon
    for param_idx in 1:N_params
        nkphsolvr.param_model_map[param_idx][:Upsilon] || continue
        _set_upsilon_boundary_eigval_derivs!(nkphtgts, param_idx,
            @view(nkphsolvr.dUpsilon[:, :, param_idx]),
            umin_Upsilon, vmin_Upsilon, umax_Upsilon, vmax_Upsilon)
    end
    # AR_hat entries: their derivative indices are handled unconditionally
    for j_idx in 1:J
        deriv_idx = _AR_idx_to_deriv_idx(nkphsolvr, j_idx)
        _set_upsilon_boundary_eigval_derivs!(nkphtgts, deriv_idx,
            @view(nkphsolvr.dUpsilon[:, :, deriv_idx]),
            umin_Upsilon, vmin_Upsilon, umax_Upsilon, vmax_Upsilon)
    end

    # --- M: first stored eigenvalue, derivatives wrt each entry M[j, k] ---
    U0_M = nkphmmts.M_eigvecs
    V0_M = inv(U0_M)'
    u0_M, v0_M = U0_M[:, 1], V0_M[:, 1]
    # k outer, j inner: iteration order matches vec(M) (column-major)
    for k_idx in 1:J, j_idx in 1:J
        deriv_idx = _m_idx_to_deriv_idx(nkphsolvr, j_idx, k_idx)
        nkphtgts.dM_min_eigval[deriv_idx] = real(
            deriv_simple_eigval(u0_M, v0_M, j_idx, k_idx)
        )
    end
    return nothing
end

# wrapper
function update_nkphtgts!(x_arr::Vector{Float64}, nkphtgts::NKPHTargets)::Nothing
    # Move the underlying moment/solver object to parameter vector x_arr. Then
    # recompute the target objects, but only if x_arr differs from the vector
    # cached in x_dict (the point at which the targets were last computed).
    update_nkphmmts!(x_arr, nkphtgts.nkphmmts)
    cached_x = nkphtgts.x_dict.vals
    x_arr == cached_x && return nothing
    cached_x[:] = x_arr
    compute_nkphtgts_objects!(nkphtgts)
    return nothing
end

function deriv_nkphtgts!(nkphtgts::NKPHTargets)::Nothing
    # Compute every target derivative object at the parameter vector cached in x_dict.
    # The underlying moment object is always differentiated. The target-level
    # derivatives are recomputed only when dx_dict records a different parameter
    # vector than x_dict.
    nkphmmts = nkphtgts.nkphmmts
    deriv_nkphmmts!(nkphmmts)

    x_vals, dx_vals = nkphtgts.x_dict.vals, nkphtgts.dx_dict.vals
    x_vals == dx_vals && return nothing
    dx_vals[:] = x_vals

    foreach(mmtvar -> deriv_mmtvar!(mmtvar, nkphmmts), nkphtgts.mmtvars_arr)
    foreach(mmtloss -> deriv_loss_val!(mmtloss, nkphmmts), nkphtgts.mmtloss_arr)
    # gradient of the total loss: accumulate per-moment gradients into a cleared buffer
    fill!(nkphtgts.dloss_val, 0.0)
    deriv_loss!(nkphtgts.dloss_val, nkphtgts.mmtloss_arr)
    # gradients of the Upsilon/M eigenvalue constraints
    deriv_min_eigvals!(nkphtgts)
    return nothing
end


########################################################################
# continuation approach to estimating model

function set_continuation_targets!(x_arr::Vector{Float64}, nkphtgts::NKPHTargets;
        valid_tol::Float64=1e-7)::Nothing
    # Prepare continuation estimation starting from parameter vector x_arr:
    #   - evaluate the targets at x_arr,
    #   - check that the solver's root_vec at x_arr is within valid_tol in
    #     absolute value (assumed: residuals of the solved system),
    #   - record the current targets bhat as the final targets (bhat_all1_arr),
    #   - record the model moments b at x_arr as the initial targets (bhat_all0_arr).
    # Throws a DomainError if the largest absolute residual exceeds valid_tol or is NaN.
    update_nkphtgts!(x_arr, nkphtgts)
    root_err = maximum(abs, nkphtgts.nkphmmts.nkphsolvr.root_vec)
    # negated `<=` so a NaN residual (diverged solve) is rejected too;
    # `root_err > valid_tol` would be false for NaN and let it through
    if !(root_err <= valid_tol)
        throw(DomainError(root_err, "initial point is not valid within tolerance"))
    end
    # final targets
    nkphtgts.bhat_all1_arr[:] = [mmtloss.bhat[1] for mmtloss in nkphtgts.mmtloss_arr]
    # initial model targets
    nkphtgts.bhat_all0_arr[:] = [mmtloss.b[1] for mmtloss in nkphtgts.mmtloss_arr]
    return nothing
end

function update_continuation_targets!(t::Float64, nkphtgts::NKPHTargets)::Nothing
    # Move every moment target along the continuation path
    #     bhat(t) = t * bhat_all1 + (1 - t) * bhat_all0
    # (t = 0: model moments at the starting point; t = 1: final targets).
    # Then invalidate both caches so the next update_nkphtgts! / deriv_nkphtgts!
    # recompute losses and gradients against the new targets. This matters even
    # when the parameter vector is unchanged.
    bhat_all0_arr = nkphtgts.bhat_all0_arr
    bhat_all1_arr = nkphtgts.bhat_all1_arr
    for (i, mmtloss) in enumerate(nkphtgts.mmtloss_arr)
        mmtloss.bhat[1] = t * bhat_all1_arr[i] + (1 - t) * bhat_all0_arr[i]
    end
    # Broadcast-assign into `vals` itself: writing through `vals[:][:]` would modify
    # a temporary copy and leave the cache intact. NaN never compares equal, so the
    # cache check fails for any next parameter vector, including an all-zero one.
    nkphtgts.x_dict.vals .= NaN
    nkphtgts.dx_dict.vals .= NaN
    return nothing
end
